{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from cache_diffusion import cachify\n",
    "from cache_diffusion.utils import (\n",
    "    PIXART_DEFAULT_CONFIG,\n",
    "    SD3_DEFAULT_CONFIG,\n",
    "    SDXL_DEFAULT_CONFIG,\n",
    "    SVD_DEFAULT_CONFIG,\n",
    ")\n",
    "from diffusers import (\n",
    "    DiffusionPipeline,\n",
    "    PixArtAlphaPipeline,\n",
    "    StableDiffusion3Pipeline,\n",
    "    StableVideoDiffusionPipeline,\n",
    ")\n",
    "from diffusers.utils import export_to_video, load_image, make_image_grid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SDXL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_inference_steps = 20\n",
    "prompt = \"beautiful lady, (freckles), big smile, blue eyes, short hair, dark makeup, hyperdetailed photography, soft light, head and shoulders portrait, cover\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setting up caching on the pipeline takes just a single API call (`cachify.prepare`).\n",
    "\n",
    "Let's first disable caching and run the baseline model so we have a reference image to compare against."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)\n",
    "cachify.disable(pipe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "baseline_img_20_steps = pipe(\n",
    "    prompt=prompt, num_inference_steps=num_inference_steps, generator=generator\n",
    ").images[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also reduce the number of steps to achieve a latency similar to that of cache diffusion. However, you will notice that the image quality is not as good."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "baseline_img_11_steps = pipe(prompt=prompt, num_inference_steps=11, generator=generator).images[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's enable caching and generate the image again with the same seed and number of steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.enable(pipe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "\n",
    "with cachify.infer(pipe) as cached_pipe:\n",
    "    cache_img = cached_pipe(\n",
    "        prompt=prompt, num_inference_steps=num_inference_steps, generator=generator\n",
    "    ).images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "make_image_grid([baseline_img_20_steps, cache_img, baseline_img_11_steps], 1, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PixArt-Alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = PixArtAlphaPipeline.from_pretrained(\n",
    "    \"PixArt-alpha/PixArt-XL-2-1024-MS\", torch_dtype=torch.float16\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")\n",
    "num_inference_steps = 30\n",
    "prompt = \"a small cactus with a happy face in the Sahara desert\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.prepare(pipe, PIXART_DEFAULT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "\n",
    "with cachify.infer(pipe) as cached_pipe:\n",
    "    img = cached_pipe(\n",
    "        prompt=prompt, generator=generator, num_inference_steps=num_inference_steps\n",
    "    ).images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SVD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = StableVideoDiffusionPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\"\n",
    ")\n",
    "pipe.enable_model_cpu_offload()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the conditioning image\n",
    "image = load_image(\n",
    "    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\"\n",
    ")\n",
    "image = image.resize((1024, 576))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.manual_seed(42)\n",
    "num_inference_steps = 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.prepare(pipe, SVD_DEFAULT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate video frames inside the cachify inference context.\n",
    "# The step count set earlier in this section is passed explicitly so that\n",
    "# changing `num_inference_steps` actually affects this run.\n",
    "with cachify.infer(pipe) as cached_pipe:\n",
    "    frames = cached_pipe(\n",
    "        image,\n",
    "        decode_chunk_size=8,\n",
    "        num_inference_steps=num_inference_steps,\n",
    "        generator=generator,\n",
    "    ).frames[0]\n",
    "\n",
    "# Write the generated frames to an MP4 file at 7 frames per second.\n",
    "export_to_video(frames, \"generated.mp4\", fps=7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SD3-Medium"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = StableDiffusion3Pipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-3-medium-diffusers\", torch_dtype=torch.float16\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")\n",
    "num_inference_steps = 28"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.prepare(pipe, SD3_DEFAULT_CONFIG)\n",
    "cachify.enable(pipe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed a CUDA generator so the run is reproducible.\n",
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "\n",
    "# Call the pipeline object yielded by the cachify context (as the other sections do)\n",
    "# and reuse the step count defined when the SD3 pipeline was loaded.\n",
    "with cachify.infer(pipe) as cached_pipe:\n",
    "    cached_img = cached_pipe(\n",
    "        \"A cat holding a sign that says hello world\",\n",
    "        negative_prompt=\"\",\n",
    "        num_inference_steps=num_inference_steps,\n",
    "        guidance_scale=7.0,\n",
    "        generator=generator,\n",
    "    ).images[0]\n",
    "cached_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ammo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
